from __future__ import division, print_function

#
# I want this to provide a reference implementation of the FFT_DI
# method for propagating forward a wavefront defined at a specific
# plane, based on the Helmholtz-Kirchhoff formulation.
#

#
# The method is due to Shen and Wang (2006), doi:10.1364/AO.45.001102
#

import numpy as np
from scipy.fftpack import next_fast_len
import pyfftw
import numba
import gc

def min_z(U_0, step, wavelength):
    """Return the minimum propagation distance that avoids aliasing.

    This is the smallest z for which the propagation has a chance of
    avoiding numerical errors due to the Nyquist sampling limit. The
    calculation is based on Mehrabkhani and Schneider (2017)
    doi:10.1364/OE.25.030229.

    The phase term exp(i*k*r) sampled with pitch d is only resolved at
    a transverse offset x when

        z**2 >= x**2 * (4 * d**2 / wavelength**2 - 1)

    and the largest offset along an axis is taken to be the full extent
    of the array along that axis (number of pixels times the pitch).

    Parameters
    ----------
    U_0 : array-like with a ``shape`` attribute
        The wavefront to be propagated; only ``shape[0]`` and
        ``shape[1]`` are read.
    step : sequence of two numbers
        Pixel pitch along axis 0 and axis 1, in the same length unit as
        ``wavelength``.
    wavelength : number
        Wavelength of the light.

    Returns
    -------
    float
        The larger of the two per-axis limits. It is 0 when the pitch
        along both axes is at most half a wavelength, because such a
        grid never aliases the kernel.
    """
    limits = [0.0]
    for n_pixels, pitch in ((U_0.shape[0], step[0]), (U_0.shape[1], step[1])):
        # The wavelength enters squared so that this ratio is dimensionless.
        ratio = 4 * pitch**2 / wavelength**2 - 1
        # A negative ratio means the axis is sampled finer than half a
        # wavelength; clamp it so the square root gives 0 instead of nan
        # (np.max would otherwise propagate the nan to the caller).
        ratio = max(ratio, 0.0)
        limits.append(np.sqrt(ratio * (n_pixels * pitch)**2))
    return np.max(limits)



@numba.njit(parallel=True)
def populate_G(G, z, wavelength, step, shape):
    """Fill G in place with the real-space propagation kernel.

    G is the large matrix defined in Shen and Wang. Every element of G
    is overwritten; nothing is returned.

    Arguments:
    G          -- 2D complex array to fill. It is normally larger than
                  the light-field array, because it includes the
                  zero-padding needed to avoid circular-convolution
                  artifacts and to reach a "good" FFT length.
    z          -- propagation distance; a negative value flips the sign
                  of the phase term (reverse propagation).
    wavelength -- wavelength of the light.
    step       -- real-space step size along axis 0 and axis 1.
    shape      -- shape of the original (unpadded) light-field array.
                  Element [shape[0]-1, shape[1]-1] of G is the one with
                  zero transverse offset.
    """
    ik = 2j * np.pi / wavelength
    # Reverse propagation is handled by negating the phase term
    if z < 0:
        ik = -ik

    scale = (step[0] * step[1] / (2 * np.pi)) * np.abs(z)

    n_rows, n_cols = G.shape
    for row in numba.prange(n_rows):
        x = (-shape[0] + 1 + row) * step[0]
        # Only the outer loop is run in parallel; the inner one is serial
        for col in range(n_cols):
            y = (-shape[1] + 1 + col) * step[1]
            r = np.sqrt(x**2 + y**2 + z**2)
            G[row, col] = scale * np.exp(ik * r) * (1/r - ik) / r**2



@numba.njit(parallel=True)
def multiply_inplace(A, B):
    """Multiply A by B element by element, storing the result in A.

    Both arguments are indexed as 2D arrays over the extent of A, so B
    must be at least as large as A along both axes. Nothing is
    returned. This hand-written loop exists because the author found
    that numpy's own multiply allocated extra memory beyond the two
    arrays involved.
    """
    n_rows, n_cols = A.shape
    for row in numba.prange(n_rows):
        # Only the outer loop is run in parallel; the inner one is serial
        for col in range(n_cols):
            A[row, col] *= B[row, col]
    

            
def FFT_DI(U_0, step, wavelength, z, verbose=False, threads=1, planning_level='FFTW_MEASURE'):
    """Propagate the wavefront U_0 a distance z with the FFT-DI method.

    The field is convolved with the kernel built by populate_G. The
    convolution is done by zero-padding both to a common FFT-friendly
    size, multiplying their transforms and transforming back.

    Parameters
    ----------
    U_0 : 2D array
        The wavefront on the starting plane. It is stored in single
        precision complex for the computation.
    step : sequence of two numbers
        Real-space step size of U_0 along axis 0 and axis 1.
    wavelength : number
        Wavelength of the light.
    z : number
        Propagation distance. Negative values are treated as reverse
        propagation by populate_G.
    verbose : bool, optional
        If True, print a progress message before each stage.
    threads : int, optional
        Number of threads FFTW uses for the FFTs.
    planning_level : str, optional
        FFTW planner flag handed to pyfftw (e.g. 'FFTW_ESTIMATE',
        'FFTW_MEASURE').

    Returns
    -------
    Q : 2D complex128 array
        The propagated wavefront, with the same shape and spacing as
        U_0.
    """

    # It's worth padding the FFT to reach an efficient length
    # which is what this padding calculation does. We can zero-pad
    # everything safely because we're going back to real space in the
    # end and it will make a big difference in performance
    # (seriously - depending on luck it can be like a factor of 100 speedup)
    padto = (next_fast_len(U_0.shape[0]*2-1),
             next_fast_len(U_0.shape[1]*2-1))

    if verbose:
        print('H being created')

    # 2 things to note.
    # 1) Generate arrays with pyfftw to get properly strided
    #    arrays that speed up computation (don't use np.empty!)
    # 2) Set up ffts before populating data, as creation of FFTW
    #    object is not guaranteed to leave the original data in place
    H = pyfftw.empty_aligned(padto,dtype='complex64')
    fft = pyfftw.FFTW(H,H,axes=(0,1),threads=threads,
                      flags=(planning_level,))
    # Planning may have scribbled over H, so the zero padding has to be
    # written after the plan exists rather than relied on from allocation.
    H.fill(0)
    H[:U_0.shape[0],:U_0.shape[1]] = U_0

    # There was some comment to do with Simpson's theorem
    # that indicated you could improve the quality by multiplying
    # H by a particular pattern. I would like to do that here
    #
    # I've decided to not use Simpson's rule, because often we want
    # to simulate features close to the mesh size.
    #

    if verbose:
        print('H created, performing FFT')

    fft()

    if verbose:
        print('FFT complete on H, creating G')

    # Again, FFTs should be set up before data population
    G = pyfftw.empty_aligned(padto,dtype='complex64')
    ifft = pyfftw.FFTW(G,G,axes=(0,1),direction='FFTW_BACKWARD',
                       threads=threads, flags=(planning_level,))
    # Reuse the forward plan for G instead of planning a second one
    fft.update_arrays(G,G)

    # populate_G writes every element of G, so nothing left over from
    # planning survives here
    populate_G(G, z, wavelength, step, np.array(U_0.shape))

    if verbose:
        print('G created, performing FFT')

    fft()

    if verbose:
        print('FFT complete on G, multiplying G and H')

    # The product of the two transforms is stored in G
    multiply_inplace(G,H)
    del H
    del fft
    gc.collect()

    if verbose:
        print('Multiplication complete, performing ifft')

    ifft()

    if verbose:
        print('IFFT Complete, copying out result to correctly sized array')

    # I do this so the reference to the large array that's 4x the size
    # of the result we want can be tossed out, leaving the program that
    # asked for it with more memory when the computation is done
    #
    # The kernel element with zero offset sits at index shape-1 along
    # each axis, so the part of the convolution that lines up with U_0
    # starts at that index.
    Q = np.empty(U_0.shape,dtype=np.complex128)
    Q[:,:] = G[U_0.shape[0]-1:2*U_0.shape[0]-1,
               U_0.shape[1]-1:2*U_0.shape[1]-1]
    del G
    del ifft
    gc.collect()

    return Q



